__author__ = 'carlxie'

from layer.ConvLayer import ConvLayer
from layer.FullConLayer import FullConLayer
from layer.MaxPoolLayer import MaxPoolLayer
from layer.Layer import Layer
from layer.TanhLayer import TanhLayer
from layer.util import *
from PIL import Image

import numpy as np

class ConvNetwork(object):
    """A feed-forward network assembled from a list of layer configs.

    Each entry of ``config`` is a dict whose ``"type"`` key selects the layer
    class (``LayerType.INPUT_TYPE``, ``CONV_TYPE``, ``FULL_CONNECT_TYPE``,
    ``MAX_POOL_TYPE`` or ``TANH_TYPE``). Every layer is wired to the layer
    built just before it. Training minimises the squared error between the
    last layer's output and ``vec_output(y, n)`` with mini-batch updates.

    Samples are rows of a 2-D array: every column but the last holds a pixel
    (reshaped to ``(28, 28, 1)``) and the last column holds the label.
    """

    def __init__(self, config, eta):
        """Build the layer stack described by ``config``.

        :param config: list of layer-config dicts, input layer first.
        :param eta: sequence of learning rates. ``train`` uses
            ``eta[t // 50]`` for epoch ``t`` and reuses the last entry once
            the sequence runs out.
        :raises ValueError: if a config entry has an unknown ``"type"``.
        """
        self.eta = eta
        self.layers = []
        pre_layer = None
        for cfg in config:
            layer_type = cfg["type"]
            if layer_type == LayerType.CONV_TYPE:
                layer = ConvLayer(pre_layer, cfg['kernelNum'], cfg['kernelSize'],
                                  cfg['stride'], cfg['padding'])
            elif layer_type == LayerType.FULL_CONNECT_TYPE:
                layer = FullConLayer(pre_layer, cfg['numNeurons'])
            elif layer_type == LayerType.MAX_POOL_TYPE:
                layer = MaxPoolLayer(pre_layer, cfg['kernelSize'], cfg['stride'])
            elif layer_type == LayerType.TANH_TYPE:
                layer = TanhLayer(pre_layer)
            elif layer_type == LayerType.INPUT_TYPE:
                layer = Layer(None, LayerType.INPUT_TYPE,
                              cfg['width'], cfg['height'], cfg['depth'])
            else:
                # Previously any unknown type silently became an input layer.
                raise ValueError("unknown layer type: %r" % (layer_type,))
            self.layers.append(layer)
            pre_layer = layer

    def _param_layers(self):
        """Return, in network order, the layers whose weights/biases are read or set.

        Input and max-pool layers are skipped.
        NOTE(review): TANH_TYPE layers are NOT skipped here, matching the
        original filter. Confirm against get_weights/get_shapes in layer.util
        before using set_weights_biases / get_analytic_grads with tanh layers.
        """
        skipped = (LayerType.INPUT_TYPE, LayerType.MAX_POOL_TYPE)
        return [layer for layer in self.layers if layer.layerType not in skipped]

    def feed_forward(self, x):
        """Store ``x`` as the input layer's output, then run forward_prop on every layer in order."""
        input_layer = self.layers[0]
        input_layer.output = x
        for layer in self.layers:
            layer.forward_prop()

    def back_forward(self, x, y):
        """Forward ``x``, seed the output delta from label ``y`` and back-propagate.

        back_prop runs from the last layer down to index 1; the input layer
        is not back-propagated.
        """
        output_layer = self.layers[-1]
        output_layer.delta = self.get_last_delta(x, y)
        for i in range(len(self.layers) - 1, 0, -1):
            self.layers[i].back_prop()

    def get_last_delta(self, x, y):
        """Forward ``x`` and return the gradient of the squared cost w.r.t. the output.

        The result is ``2 * (output - target)``, computed on the 2-D slice
        ``output[:, :, 0]``. The target is ``vec_output(y, n)``, presumably a
        one-hot column of length ``n`` (TODO confirm in layer.util).
        """
        self.feed_forward(x)
        output_layer = self.layers[-1]
        out = output_layer.output[:, :, 0]
        diff = (vec_output(y, len(out)) - out) * 2
        return -diff

    def train(self, training_data, T, mini_batch_size, verbose=False):
        """Run ``T`` epochs of mini-batch training on ``training_data``.

        ``training_data`` is shuffled in place at the start of each epoch.
        Within each mini-batch, ``back_forward`` is called for every sample,
        then ``update_weights`` is called once. Epoch ``t`` uses the learning
        rate ``self.eta[t // 50]``, clamped to the last entry so long runs do
        not raise IndexError.

        :param training_data: 2-D array; each row is 784 pixels followed by a label.
        :param T: number of epochs.
        :param mini_batch_size: samples per weight update (the last batch may be smaller).
        :param verbose: if true, also print each batch's mean squared-error loss.
            This costs one extra forward pass per sample, so it is off by default.
        """
        last_stage = len(self.eta) - 1
        for t in range(T):
            print("epoch ,,,,,, " + str(t))
            np.random.shuffle(training_data)
            eta = self.eta[min(t // 50, last_stage)]
            mini_batches = [training_data[k:k + mini_batch_size]
                            for k in range(0, len(training_data), mini_batch_size)]
            for batch in mini_batches:
                loss = 0.0
                for data in batch:
                    x = data[:-1].reshape(28, 28, 1)
                    y = data[-1]
                    self.back_forward(x, y)
                    if verbose:
                        loss += self.cost(x, y)
                if verbose:
                    # Average over the real batch length; the last batch can be short.
                    print(loss / len(batch))
                self.update_weights(eta)

    def update_weights(self, eta):
        """Ask every layer to apply its update with learning rate ``eta``."""
        for layer in self.layers:
            layer.update_weights(eta)

    def cost(self, x, y):
        """Return the squared error between the network output for ``x`` and the target for ``y``."""
        self.feed_forward(x)
        output_layer = self.layers[-1]
        net_output = output_layer.output[:, 0, 0]
        expect_output = vec_output(y, len(net_output))[:, 0]
        return self.square_cost(expect_output, net_output)

    def square_cost(self, expect, real):
        """Return the sum of squared differences between ``expect`` and ``real``."""
        return np.sum((expect - real) ** 2)

    def get_analytic_grads(self, x, y):
        """Back-propagate one sample and return its gradients as a flat list.

        The list holds every parameter layer's ``w_grads[0]`` followed by
        every parameter layer's ``b_grads[0]``, flattened with ``flat_list``.
        """
        self.back_forward(x, y)
        param_layers = self._param_layers()
        grads = [layer.w_grads[0] for layer in param_layers]
        grads.extend(layer.b_grads[0] for layer in param_layers)
        return flat_list(grads)

    def set_weights_biases(self, weights, biases):
        """Assign ``weights[i]`` / ``biases[i]`` to the i-th parameter layer."""
        for index, layer in enumerate(self._param_layers()):
            layer.weights = weights[index]
            layer.biases = biases[index]

    def cal_loss(self, x, y, w, b):
        """Install weights ``w`` and biases ``b``, then return the cost for ``(x, y)``."""
        self.set_weights_biases(w, b)
        return self.cost(x, y)

    def predict(self, x):
        """Return the index of the largest output unit for input ``x``."""
        self.feed_forward(x)
        return np.argmax(self.layers[-1].output[:, 0, 0])

    def eval(self, X, Y):
        """Return the fraction of rows of ``X`` whose prediction differs from ``Y``."""
        errors = 0.0
        for i, row in enumerate(X):
            if self.predict(row.reshape(28, 28, 1)) != Y[i]:
                errors += 1
        return errors / len(X)

    def evaluate(self, X, Y):
        """Return the fraction of rows of ``X`` whose prediction equals ``Y`` (accuracy)."""
        hits = [self.predict(row.reshape(28, 28, 1)) == Y[i]
                for i, row in enumerate(X)]
        return sum(hits) / float(len(X))


def grad_relative_error(num_grads, analytic_grads):
    """Return ``||num - analytic|| / (||num|| + ||analytic||)`` for two gradient vectors.

    Euclidean norms are used instead of plain sums so that positive and
    negative per-component errors cannot cancel each other out. Returns 0.0
    when both vectors are entirely zero, which avoids a 0/0 division.
    """
    num_grads = np.asarray(num_grads, dtype=float)
    analytic_grads = np.asarray(analytic_grads, dtype=float)
    den = np.linalg.norm(num_grads) + np.linalg.norm(analytic_grads)
    if den == 0:
        return 0.0
    return np.linalg.norm(num_grads - analytic_grads) / den


def grad_check(data_path="train.dat.npy"):
    """Compare numerical and back-propagated gradients of a tiny conv net on one sample.

    :param data_path: ``.npy`` file whose rows are 784 pixels followed by a
        label. Only the first row is used.
    :returns: the relative error from ``grad_relative_error``. Small values
        mean the analytic gradients agree with the numerical ones.
    """
    config = [
        {"type": LayerType.INPUT_TYPE, "width": 28, "height": 28, "depth": 1},
        {"type": LayerType.CONV_TYPE, "kernelNum": 1, "kernelSize": 3, "stride": 1, "padding": 0},
        {"type": LayerType.MAX_POOL_TYPE, "kernelSize": 2, "stride": 2},
        {"type": LayerType.CONV_TYPE, "kernelNum": 1, "kernelSize": 2, "stride": 1, "padding": 0},
        {"type": LayerType.MAX_POOL_TYPE, "kernelSize": 2, "stride": 2},
        {"type": LayerType.FULL_CONNECT_TYPE, "numNeurons": 5},
        {"type": LayerType.FULL_CONNECT_TYPE, "numNeurons": 10},
    ]
    images = np.load(data_path)
    data = images[0]
    x = data[:-1].reshape(28, 28, 1)
    y = data[-1]
    convNet = ConvNetwork(config, [0.03, 0.01])

    ws = get_weights(convNet)
    bs = get_biases(convNet)
    weights = np.append(ws, bs)
    shapes = get_shapes(convNet)

    num_grads = compute_num_grads(convNet, x, y, shapes, weights)

    # Re-install the original parameters before back-propagating; compute_num_grads
    # presumably leaves perturbed weights in the network (TODO confirm in layer.util).
    origin_weights, origin_biases = reconstruct(weights, shapes)
    convNet.set_weights_biases(origin_weights, origin_biases)
    analytic_grads = convNet.get_analytic_grads(x, y)
    print(num_grads)
    print(analytic_grads)

    rel_error = grad_relative_error(num_grads, analytic_grads)
    print(rel_error)
    return rel_error

def test_mnist():
    """Train a small tanh MLP on an MNIST subset and print its error rates.

    Reads ``train.dat.npy`` and ``test.dat.npy`` from the working directory
    (rows of 784 pixel values followed by a label). Trains on the first 300
    training rows for 50 epochs with mini-batches of 10, then prints the
    in-sample error (``e_in``) and the error on test rows 300-499 (``e_out``).
    """
    train_data = np.load("train.dat.npy")
    test_data = np.load("test.dat.npy")

    # 784 -> 30 -> 10 fully connected network, tanh after each dense layer.
    # A convolutional front end can be spliced in right after the input layer:
    #   {"type": LayerType.CONV_TYPE, "kernelNum": 5, "kernelSize": 5, "stride": 1, "padding": 0},
    #   {"type": LayerType.TANH_TYPE},
    #   {"type": LayerType.MAX_POOL_TYPE, "kernelSize": 2, "stride": 2},
    #   {"type": LayerType.CONV_TYPE, "kernelNum": 16, "kernelSize": 3, "stride": 1, "padding": 0},
    #   {"type": LayerType.MAX_POOL_TYPE, "kernelSize": 2, "stride": 2},
    config = [
        {"type": LayerType.INPUT_TYPE, "width": 28, "height": 28, "depth": 1},
        {"type": LayerType.FULL_CONNECT_TYPE, "numNeurons": 30},
        {"type": LayerType.TANH_TYPE},
        {"type": LayerType.FULL_CONNECT_TYPE, "numNeurons": 10},
        {"type": LayerType.TANH_TYPE},
    ]

    train_subset = train_data[:300]
    test_subset = test_data[300:500]

    net = ConvNetwork(config, [0.03, 0.01])
    net.train(train_subset, 50, 10)

    # Optional checkpointing (disabled):
    #   np.save("trained_weights", np.append(get_weights(net), get_biases(net)))
    #   w, b = reconstruct(np.load("trained_weights.npy"), get_shapes(net))
    #   net.set_weights_biases(w, b)

    # Report training error first, then held-out error.
    for tag, subset in (("e_in ", train_subset), ("e_out: ", test_subset)):
        rate = net.eval(subset[:, :-1], subset[:, -1])
        print(tag + str(rate))

def showImage(index=0, path="train.dat.npy"):
    """Show one training digit in the default image viewer.

    :param index: row of the data file to display (default: the first row).
    :param path: ``.npy`` file whose rows are 784 pixels followed by a label.

    Pixels are scaled by 255, on the assumption that they are stored in
    [0, 1] (TODO confirm). They are then rounded, clipped to 0..255 and cast
    to uint8, so PIL builds an 8-bit grayscale ("L") image rather than a
    32-bit float ("F") image.
    """
    train_data = np.load(path)
    pixels = train_data[index, :-1] * 255
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    img = Image.fromarray(pixels.reshape(28, 28))
    img.show()

if __name__ == "__main__":
    # Other entry points, enable one at a time for manual experiments:
    #   showImage()  - display the first training digit
    #   grad_check() - compare numerical and analytic gradients
    test_mnist()


